import torch
from diffusers.models.resnet import ResnetBlock2D
from diffusers.utils import deprecate

from .cache_base import CacheBase


class CustomResnetBlock2D(ResnetBlock2D, CacheBase):
    """`ResnetBlock2D` whose residual branch can be recorded or overridden.

    The forward pass mirrors the stock residual block. After the second
    convolution, and only when ``use_cache`` is true, one extra step runs:

    * ``mode == "invert"``: the branch output is fused into
      ``self.output_cache`` (see :meth:`save_cache`).
    * ``mode == "generate"``: the first half of the batch is overwritten,
      in place, with the second half of the batch.

    NOTE(review): ``load_feature`` and :meth:`load_cache` are not used by
    ``forward`` in this class; they are presumably driven by external code.
    """

    # Outputs are fused into the cache only while the current timestep `t`
    # is at least this value. Override on a subclass or instance to change it.
    _CACHE_MIN_T = 40

    def __init__(self, use_cache=True, *args, **kwargs):
        """Build the block and initialise the cache state.

        Args:
            use_cache: When false, ``forward`` skips both the save step and
                the batch-overwrite step regardless of ``mode``.
            *args, **kwargs: Forwarded unchanged to ``ResnetBlock2D.__init__``.
                Because ``use_cache`` is the first positional parameter,
                pass the parent's arguments by keyword.
        """
        super().__init__(*args, **kwargs)
        self.mode = 'invert'
        self.load_feature = False
        self.t = None

        # Cache state.
        self.use_cache = use_cache
        self.output_cache = None
        self.fusion_mode = "min"  # "mean", "max" or "min"
        self.mean_factor = 1 / 10

    def set_cur_t(self, cur_t):
        """Record the current timestep; it gates caching in `save_cache`."""
        self.t = cur_t

    def set_invert_or_generate(self, mode="invert"):
        """Select the forward behaviour: ``"invert"`` or ``"generate"``."""
        self.mode = mode

    def set_load_mode(self, mode):
        """Store ``mode`` in ``self.load_feature`` (not read by this class)."""
        self.load_feature = mode

    def reset_cache(self):
        """Drop the cached tensor so the next save starts from scratch."""
        # Rebinding releases the reference; a preceding `del` is unnecessary.
        self.output_cache = None

    def save_cache_to_file(self):
        """No-op placeholder; this block keeps its cache in memory only."""
        pass

    def load_cache_from_file(self):
        """No-op placeholder; this block keeps its cache in memory only."""
        pass

    @torch.no_grad()
    def save_cache(self, output):
        """Fuse ``output`` into ``self.output_cache`` per ``fusion_mode``.

        Nothing is stored while ``self.t`` is below ``_CACHE_MIN_T``.

        Fusion modes:
            * ``"mean"``: accumulates ``mean_factor * output``. This equals the
              arithmetic mean only when exactly ``1 / mean_factor`` outputs
              are accumulated.
            * ``"max"`` / ``"min"``: element-wise maximum / minimum of the
              cache and ``output``.

        Raises:
            TypeError: If the timestep was never set via `set_cur_t`.
            ValueError: If ``fusion_mode`` is not one of the supported modes.
        """
        if self.t is None:
            # Same exception type the bare `None >= int` comparison would
            # raise, but with an actionable message.
            raise TypeError(
                "Current timestep is not set; call `set_cur_t` before running in 'invert' mode."
            )
        if not self.t >= self._CACHE_MIN_T:
            return

        current = output.detach()
        if self.fusion_mode == "mean":
            # The multiplication already allocates a fresh tensor.
            weighted = self.mean_factor * current
            self.output_cache = weighted if self.output_cache is None else self.output_cache + weighted
        elif self.fusion_mode == "max":
            # Clone only on first store so the cache never aliases `output`.
            self.output_cache = current.clone() if self.output_cache is None else torch.max(self.output_cache, current)
        elif self.fusion_mode == "min":
            self.output_cache = current.clone() if self.output_cache is None else torch.min(self.output_cache, current)
        else:
            raise ValueError(f"Invalid fusion mode: {self.fusion_mode}. Choose from 'mean', 'max', or 'min'.")

    @torch.no_grad()
    def load_cache(self, B, device):
        """Return the cache concatenated with itself along dim 0 on ``device``.

        Args:
            B: Unused; kept for interface compatibility.
            device: Target device for the returned tensor.

        Returns:
            The doubled cache tensor, or ``None`` when nothing is cached.
        """
        if self.output_cache is None:
            return None
        doubled = torch.cat([self.output_cache, self.output_cache])
        return doubled.to(device, non_blocking=True)

    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the residual block with the optional cache step.

        Args:
            input_tensor: Block input; also feeds the skip connection.
            temb: Time embedding, or ``None`` if the block does not use one.
            *args, **kwargs: Accepted only to emit the ``scale`` deprecation
                warning; otherwise ignored.

        Returns:
            ``(shortcut(input_tensor) + branch) / output_scale_factor``.

        Raises:
            ValueError: If ``temb`` is ``None`` with ``scale_shift`` time
                normalisation, or if ``mode`` is invalid while caching is on.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = input_tensor

        hidden_states = self.norm1(hidden_states)
        hidden_states = self.nonlinearity(hidden_states)

        if self.upsample is not None:
            # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
            if hidden_states.shape[0] >= 64:
                input_tensor = input_tensor.contiguous()
                hidden_states = hidden_states.contiguous()
            input_tensor = self.upsample(input_tensor)
            hidden_states = self.upsample(hidden_states)
        elif self.downsample is not None:
            input_tensor = self.downsample(input_tensor)
            hidden_states = self.downsample(hidden_states)

        hidden_states = self.conv1(hidden_states)

        if self.time_emb_proj is not None:
            if not self.skip_time_act:
                temb = self.nonlinearity(temb)
            temb = self.time_emb_proj(temb)[:, :, None, None]

        if self.time_embedding_norm == "default":
            if temb is not None:
                hidden_states = hidden_states + temb
            hidden_states = self.norm2(hidden_states)
        elif self.time_embedding_norm == "scale_shift":
            if temb is None:
                raise ValueError(
                    f" `temb` should not be None when `time_embedding_norm` is {self.time_embedding_norm}"
                )
            time_scale, time_shift = torch.chunk(temb, 2, dim=1)
            hidden_states = self.norm2(hidden_states)
            hidden_states = hidden_states * (1 + time_scale) + time_shift
        else:
            hidden_states = self.norm2(hidden_states)

        hidden_states = self.nonlinearity(hidden_states)

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if not self.use_cache:
            pass
        elif self.mode == "invert":
            # During inversion, fuse the branch output into the cache.
            self.save_cache(hidden_states)
        elif self.mode == "generate":
            # During generation, overwrite the first half of the batch with
            # the second half, in place. Assumes an even batch size.
            # NOTE(review): presumably the second half holds the reference
            # branch; confirm against the calling pipeline.
            half = input_tensor.shape[0] // 2
            hidden_states[:half] = hidden_states[half:]
        else:
            raise ValueError(f"Invalid mode: {self.mode}. Choose from 'invert' or 'generate'.")

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor

        return output_tensor
